"""
Phase 2: Baseline Regressions — Pension Fund Home Bias
=======================================================
Tests:
  (a) Z → pension spending/GDP (does aging predict pension size?)
  (b) Z → gross external assets (does aging drive cross-border diversification?)
  (c) Z → portfolio composition (debt vs equity vs FDI shares)
  (d) Z × pension → external assets (pension as mechanism)
  (e) Robustness: lagged Z (lag 5), ΔZ, OADR+20, and income terciles
      (OECD vs non-OECD splits are run within (a)–(c))

The key hypothesis: aging → pension AUM growth → home market too small to
absorb those assets → cross-border portfolio diversification. This would
explain why Z affects portfolio flows but not FDI (Paper 2). Pension size is
proxied below by pension spending/GDP.

Output (output/tables/): first_stage_pension.md, diversification.md,
portfolio_composition.md, mechanism.md, robustness.md
"""

import sys
from pathlib import Path

import numpy as np
import pandas as pd

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/pension_home_bias")
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Locations of the processed input panel and of the generated markdown tables.
PROCESSED_DIR = PROJECT_DIR.joinpath("data", "processed")
TABLES_DIR = PROJECT_DIR.joinpath("output", "tables")

# ISO3 codes of the 38 OECD member countries.
OECD_38 = """
    AUS AUT BEL CAN CHL COL CRI CZE DNK EST
    FIN FRA DEU GRC HUN ISL IRL ISR ITA JPN
    KOR LVA LTU LUX MEX NLD NZL NOR POL PRT
    SVK SVN ESP SWE CHE TUR GBR USA
""".split()


def stars(p):
    """Return the conventional significance marker for p-value *p*.

    '***' for p < 0.01, '**' for p < 0.05, '*' for p < 0.10, otherwise ''.
    A NaN p-value compares false against every cutoff and therefore yields ''.
    """
    for cutoff, marker in ((0.01, '***'), (0.05, '**'), (0.10, '*')):
        if p < cutoff:
            return marker
    return ''


def run_model(df, dep_var, regressors, label, min_obs=50):
    """Estimate one PanelGLS specification and return a summary dict.

    Regressors that are not columns of *df* are dropped, and a note is
    printed, so optional variables can be listed unconditionally by callers.
    The model is skipped (returning None) when the dependent variable is
    missing, when no regressor is available, or when fewer than *min_obs*
    complete-case rows remain.

    Parameters
    ----------
    df : pandas.DataFrame
        Panel containing ``iso3`` and ``year`` columns, *dep_var* and the
        regressors.
    dep_var : str
        Dependent-variable column.
    regressors : list of str
        Candidate regressor columns, in reporting order.
    label : str
        Model label used in console output and in the results dict.
    min_obs : int, default 50
        Minimum number of complete-case rows required to estimate.

    Returns
    -------
    dict or None
        Keys ``label``, ``dep_var``, ``n_obs``, ``n_countries``,
        ``r_squared`` plus ``coef_<name>``, ``se_<name>`` and ``p_<name>``
        for every regressor used; None when the model was skipped.
    """
    if dep_var not in df.columns:
        print(f"  [{label}] {dep_var} missing — skipping")
        return None

    present = [r for r in regressors if r in df.columns]
    dropped = [r for r in regressors if r not in df.columns]
    if dropped:
        # Make silently-dropped regressors visible: otherwise a "Z → y" model
        # can quietly become a controls-only regression.
        print(f"  [{label}] not in panel, dropped: {', '.join(dropped)}")
    if not present:
        print(f"  [{label}] no regressors available — skipping")
        return None

    # Complete-case sample; rows without panel identifiers are unusable too.
    sub = df.dropna(subset=[dep_var] + present + ['iso3', 'year'])
    if len(sub) < min_obs:
        print(f"  [{label}] Insufficient obs ({len(sub)}) — skipping")
        return None

    gls = PanelGLS()
    gls.fit(sub[dep_var].values, sub[present].values,
            sub['iso3'].values, sub['year'].values)
    print(f"\n  [{label}]  N={gls.n_obs}, countries={gls.n_countries}, R²={gls.r_squared:.4f}")
    results = {
        'label': label, 'dep_var': dep_var,
        'n_obs': gls.n_obs, 'n_countries': gls.n_countries,
        'r_squared': gls.r_squared,
    }
    # NOTE(review): assumes gls.beta / se / pvalues are ordered like the
    # regressor columns passed to fit (no leading intercept) — confirm in model.py.
    for i, name in enumerate(present):
        results[f'coef_{name}'] = gls.beta[i]
        results[f'se_{name}'] = gls.se[i]
        results[f'p_{name}'] = gls.pvalues[i]
        sig = stars(gls.pvalues[i])
        print(f"    {name:<30} {gls.beta[i]:>10.4f} ({gls.se[i]:.4f}) {sig}")
    return results


def build_table(results, key_vars, notes, filename, title, out_dir=None):
    """Write a markdown summary of fitted models to ``<out_dir>/<filename>``.

    The file contains a model table (label, dependent variable, N, countries,
    R²), then a coefficient table with one row per (model, variable) for each
    variable of *key_vars* present in that model's results, then *notes* in
    italics.

    Parameters
    ----------
    results : list of dict
        Result dicts as produced by run_model. Nothing is written when empty.
    key_vars : list of str
        Regressor names whose coefficients are reported.
    notes : str
        Footnote text.
    filename : str
        Output file name.
    title : str
        Markdown H1 title.
    out_dir : path-like, optional
        Output directory; defaults to TABLES_DIR. Created if it does not exist.

    Returns
    -------
    pathlib.Path or None
        The path written, or None when *results* is empty.
    """
    if not results:
        return None
    out_dir = TABLES_DIR if out_dir is None else Path(out_dir)

    md = [
        f"# {title}\n",
        "| Model | Dep Var | N | Countries | R² |",
        "|---|---|---|---|---|",
    ]
    for r in results:
        md.append(f"| {r['label']} | {r['dep_var']} | {r['n_obs']:,} "
                  f"| {r['n_countries']} | {r['r_squared']:.3f} |")
    md.append("\n## Key Coefficients\n")
    md.append("| Model | Variable | Coef | SE | p-value | Sig |")
    md.append("|---|---|---|---|---|---|")
    for r in results:
        for var in key_vars:
            ckey = f'coef_{var}'
            if ckey in r:
                p = r[f'p_{var}']
                md.append(f"| {r['label']} | {var} | {r[ckey]:.4f} "
                          f"| {r[f'se_{var}']:.4f} | {p:.4f} | {stars(p)} |")
    md.append(f"\n*{notes}*")

    # The tables directory may not exist yet (e.g. on a fresh checkout).
    out_dir.mkdir(parents=True, exist_ok=True)
    out = out_dir / filename
    # Explicit UTF-8: the text contains non-ASCII characters (R², →, ×, Δ)
    # that a non-UTF-8 platform default encoding cannot represent.
    out.write_text('\n'.join(md) + '\n', encoding='utf-8')
    print(f"\n  Saved: {out}")
    return out


def _split_oecd(df):
    """Return ``(oecd, non_oecd)`` copies of *df*.

    Uses the panel's ``oecd`` indicator (1 = member, 0 = non-member) when the
    column exists; rows with any other value (e.g. NaN) fall in neither
    subset. Without the column, membership is taken from ``iso3`` in OECD_38
    instead of raising KeyError.
    """
    if 'oecd' in df.columns:
        return df[df['oecd'] == 1].copy(), df[df['oecd'] == 0].copy()
    print("  [warn] 'oecd' column missing — classifying by OECD_38 membership")
    member = df['iso3'].isin(OECD_38)
    return df[member].copy(), df[~member].copy()


def _fit(results, data, dep_var, regressors, label, require):
    """Run run_model and append a non-None result to *results*.

    Skips the model when none of the *require* columns exist in *data*, so a
    specification labelled e.g. "Z_lag5 → y" is never estimated on controls
    alone. Returns the result dict, or None when skipped.
    """
    if not any(v in data.columns for v in require):
        print(f"  [{label}] none of {', '.join(require)} in panel — skipping")
        return None
    r = run_model(data, dep_var, regressors, label)
    if r:
        results.append(r)
    return r


def _find_result(results, label):
    """Return the first result dict in *results* carrying *label*, else None."""
    return next((r for r in results if r['label'] == label), None)


def _attenuation_pct(base, mediated, var='Z_1'):
    """Percent by which ``coef_<var>`` shrinks from *base* to *mediated*.

    Computes ``(1 - b_mediated / b_base) * 100``. Returns None when either
    result is missing, either coefficient is absent or non-finite, or the base
    coefficient is exactly zero. A mediated coefficient of 0.0 is a valid
    (100 %) attenuation, hence the explicit ``is None`` tests rather than
    truthiness.
    """
    if base is None or mediated is None:
        return None
    b_base = base.get(f'coef_{var}')
    b_med = mediated.get(f'coef_{var}')
    if b_base is None or b_med is None:
        return None
    if not (np.isfinite(b_base) and np.isfinite(b_med)) or b_base == 0:
        return None
    return (1 - b_med / b_base) * 100


def main():
    """Run the Phase 2 regressions (Parts A–E) and write one markdown table per part.

    Reads ``pension_panel.csv`` from PROCESSED_DIR. Every model is estimated
    with run_model and tabulated with build_table. A model is skipped, rather
    than estimated on controls alone, when none of its key regressors are in
    the panel.
    """
    print("=" * 70)
    print("PHASE 2: Baseline Regressions — Pension Fund Home Bias")
    print("=" * 70)

    df = pd.read_csv(PROCESSED_DIR / "pension_panel.csv")
    print(f"Panel: {df['iso3'].nunique()} countries, {len(df):,} obs")

    demo_vars = ['Z_1', 'Z_2', 'Z_3']
    age_vars = ['old_dep', 'youth_dep']
    missing_z = [v for v in demo_vars if v not in df.columns]
    if missing_z:
        print(f"  [warn] demographic factors missing from panel: {', '.join(missing_z)}")
    controls = [c for c in ['rgdp_growth', 'fiscal_bal_gdp', 'kaopen']
                if c in df.columns]
    controls_inc = controls + (['log_gdp_pc'] if 'log_gdp_pc' in df.columns else [])
    z_spec = demo_vars + controls_inc
    age_spec = age_vars + controls_inc
    pension = ['pension_spending_gdp']
    has_pension = 'pension_spending_gdp' in df.columns

    oecd, non_oecd = _split_oecd(df)

    # ═══════════════════════════════════════════════════════════════════
    # PART A: Z → PENSION SIZE (first stage)
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART A: Z → PENSION SIZE (first stage)")
    print("=" * 70)

    a_results = []
    if has_pension:
        _fit(a_results, df, 'pension_spending_gdp', z_spec,
             "A1: Z → pension_spending", demo_vars)
        _fit(a_results, df, 'pension_spending_gdp', age_spec,
             "A2: age ratios → pension", age_vars)
        _fit(a_results, oecd, 'pension_spending_gdp', z_spec,
             "A3: OECD Z → pension", demo_vars)
        _fit(a_results, non_oecd, 'pension_spending_gdp', z_spec,
             "A4: non-OECD Z → pension", demo_vars)

    build_table(a_results, demo_vars + age_vars,
                "First stage: demographics predict pension system size",
                "first_stage_pension.md",
                "First Stage: Demographics → Pension Size")

    # ═══════════════════════════════════════════════════════════════════
    # PART B: Z → CROSS-BORDER DIVERSIFICATION
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART B: Z → CROSS-BORDER DIVERSIFICATION")
    print("=" * 70)

    b_results = []
    _fit(b_results, df, 'gross_assets_gdp', z_spec,
         "B1: Z → gross_assets/GDP", demo_vars)
    if 'portfolio_assets_gdp' in df.columns:
        _fit(b_results, df, 'portfolio_assets_gdp', z_spec,
             "B2: Z → portfolio_assets/GDP", demo_vars)
    if 'diversification' in df.columns:
        _fit(b_results, df, 'diversification', z_spec,
             "B3: Z → diversification", demo_vars)
    _fit(b_results, oecd, 'gross_assets_gdp', z_spec,
         "B4: OECD Z → gross_assets", demo_vars)
    _fit(b_results, non_oecd, 'gross_assets_gdp', z_spec,
         "B5: non-OECD Z → gross_assets", demo_vars)
    _fit(b_results, df, 'gross_assets_gdp', age_spec,
         "B6: age ratios → gross_assets", age_vars)

    build_table(b_results, demo_vars + age_vars,
                f"Controls: {', '.join(controls_inc)}",
                "diversification.md",
                "Demographics → Cross-Border Diversification")

    # ═══════════════════════════════════════════════════════════════════
    # PART C: Z → PORTFOLIO COMPOSITION
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART C: Z → PORTFOLIO COMPOSITION")
    print("=" * 70)

    # Test: aging tilts external assets toward debt (pension-consistent)
    # and away from FDI (illiquid, not pension-appropriate).
    c_results = []
    for dep, label in [
        ('debt_share_assets', 'C1: Z → debt share'),
        ('fdi_share_assets', 'C2: Z → FDI share'),
        ('debt_assets_gdp', 'C3: Z → debt assets/GDP'),
        ('port_eq_assets_gdp', 'C4: Z → equity assets/GDP'),
        ('fdi_assets_gdp', 'C5: Z → FDI assets/GDP'),
    ]:
        _fit(c_results, df, dep, z_spec, label, demo_vars)
    _fit(c_results, oecd, 'debt_share_assets', z_spec,
         "C6: OECD Z → debt share", demo_vars)
    _fit(c_results, non_oecd, 'debt_share_assets', z_spec,
         "C7: non-OECD Z → debt share", demo_vars)

    build_table(c_results, demo_vars,
                "Portfolio composition: does aging tilt toward debt and away from FDI?",
                "portfolio_composition.md",
                "Demographics → Portfolio Composition (Debt vs Equity vs FDI)")

    # ═══════════════════════════════════════════════════════════════════
    # PART D: PENSION AS MECHANISM
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART D: PENSION AS MECHANISM")
    print("=" * 70)

    d_results = []
    z_pension_int = [f'{zv}_x_pension' for zv in demo_vars
                     if f'{zv}_x_pension' in df.columns]
    base_same_sample = None
    d2 = None

    if has_pension:
        # D1: pension → gross assets (second stage)
        _fit(d_results, df, 'gross_assets_gdp', pension + controls_inc,
             "D1: pension → gross_assets", pension)

        # Baseline Z model re-estimated on rows with pension data. D2 can only
        # use those rows, so comparing it with the full-sample B1 would mix
        # mediation with a change in sample composition.
        base_same_sample = _fit(d_results, df.dropna(subset=pension),
                                'gross_assets_gdp', z_spec,
                                "D2 base: Z → gross_assets (pension sample)",
                                demo_vars)

        # D2: Z → gross assets controlling for pension (mediation)
        d2 = _fit(d_results, df, 'gross_assets_gdp', z_spec + pension,
                  "D2: Z → gross_assets | pension", demo_vars)

        # D3: Z × pension → gross assets
        if z_pension_int:
            _fit(d_results, df, 'gross_assets_gdp',
                 z_spec + pension + z_pension_int,
                 "D3: Z×pension → gross_assets", z_pension_int)

        # D4: pension → debt share; D5: pension → FDI share (expect null or negative)
        _fit(d_results, df, 'debt_share_assets', pension + controls_inc,
             "D4: pension → debt_share", pension)
        _fit(d_results, df, 'fdi_share_assets', pension + controls_inc,
             "D5: pension → FDI_share", pension)

    # Attenuation test: how much of the Z₁ effect disappears once pension
    # size is controlled for.
    att = _attenuation_pct(base_same_sample, d2)
    if att is not None:
        print(f"\n  ★ Z₁ attenuation (same sample, base → + pension): {att:.1f}%")
    att_full = _attenuation_pct(_find_result(b_results, "B1: Z → gross_assets/GDP"), d2)
    if att_full is not None:
        print(f"    vs full-sample B1 (samples differ): {att_full:.1f}%")

    key_d = pension + demo_vars + z_pension_int
    build_table(d_results, key_d,
                "Pension mechanism: does pension size mediate Z → external assets?",
                "mechanism.md",
                "Pension Mechanism: Demographics → Pension → External Assets")

    # ═══════════════════════════════════════════════════════════════════
    # PART E: LAGGED & ROBUSTNESS
    # ═══════════════════════════════════════════════════════════════════
    print("\n" + "=" * 70)
    print("PART E: LAGGED & ROBUSTNESS")
    print("=" * 70)

    e_results = []
    lag_vars = ['Z_1_lag5', 'Z_2_lag5', 'Z_3_lag5']
    diff_vars = ['d_Z_1', 'd_Z_2', 'd_Z_3']

    _fit(e_results, df, 'gross_assets_gdp', lag_vars + controls_inc,
         "E1: Z_lag5 → gross_assets", lag_vars)
    _fit(e_results, df, 'gross_assets_gdp', diff_vars + controls_inc,
         "E2: ΔZ → gross_assets", diff_vars)
    if 'oadr_plus20' in df.columns:
        _fit(e_results, df, 'gross_assets_gdp', ['oadr_plus20'] + controls_inc,
             "E3: OADR+20 → gross_assets", ['oadr_plus20'])

    # Income terciles use `controls` (without log_gdp_pc), presumably because
    # income already defines the subsample.
    if 'income_group' in df.columns:
        for group in ['low', 'mid', 'high']:
            sub = df[df['income_group'] == group]
            _fit(e_results, sub, 'gross_assets_gdp', demo_vars + controls,
                 f"E4: {group}-income Z → gross_assets", demo_vars)
    else:
        print("  [E4] income_group column missing — skipping income terciles")

    build_table(e_results, demo_vars + lag_vars + diff_vars + ['oadr_plus20'],
                "Robustness checks",
                "robustness.md",
                "Robustness: Demographics → External Assets")

    print("\n" + "=" * 70)
    print("Phase 2 complete.")
    print("=" * 70)


# Script entry point: run all Phase 2 regressions only when executed directly,
# not when the module is imported.
if __name__ == "__main__":
    main()
